"""train process"""

import argparse
import os
import random
import time

import numpy as np
import torch
from torch import nn
from torch.utils import data as Data
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
)
import torch.multiprocessing as mp
import functools
import pathlib

from src import (
    calculate_l2_error,
    create_datasets,
    load_yaml_config,
    pad_collate,
    tokmak_model,
)


# Fix RNG seeds for Python, NumPy and torch (CPU) so runs are reproducible.
# NOTE(review): torch.cuda.manual_seed_all is not called here, so GPU-side
# randomness may still vary between runs — confirm whether that is intended.
seed = 123456
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


def parse_args():
    """Parse input args"""
    parser = argparse.ArgumentParser(description="tokmak train")
    parser.add_argument(
        "--config_file_path", type=str, default="./configs/tokamak.yaml"
    )
    parser.add_argument(
        "--device_target",
        type=str,
        default="Ascend",
        choices=["GPU", "Ascend"],
        help="The target device to run, support 'Ascend', 'GPU'",
    )
    parser.add_argument(
        "--device_id", type=int, default=0, help="ID of the target device"
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="GRAPH",
        choices=["GRAPH", "PYNATIVE"],
        help="Running in GRAPH_MODE OR PYNATIVE_MODE",
    )
    parser.add_argument(
        "--save_graphs",
        type=bool,
        default=False,
        choices=[True, False],
        help="Whether to save intermediate compilation graphs",
    )
    parser.add_argument("--save_graphs_path", type=str, default="./graphs")

    input_args = parser.parse_args()
    return input_args


def setup(rank, world_size):
    """Join the default "nccl" process group for this worker.

    Args:
        rank (int): this process's rank within the group.
        world_size (int): total number of participating processes.
    """
    # Robustness fix: respect a rendezvous address/port already provided by
    # the environment (e.g. a cluster launcher) instead of clobbering it;
    # only fall back to the single-node localhost defaults.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def train_process(
    model,
    rank,
    cur_steps,
    optimizer,
    scheduler,
    train_loader,
    out_channels,
    lossfn_config,
):
    """
    Run one training epoch and reduce the loss across all ranks.

    Args:
        model: network to train (FSDP-wrapped by the caller).
        rank (int): local CUDA device index for this process.
        cur_steps (int): running count of samples seen so far; advanced here
            and forwarded to calculate_l2_error.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        train_loader: yields (X, valid_len, valid_channels, h5_files) batches;
            h5_files is unused here.
        out_channels (int): number of trailing channels of X that are labels;
            the leading channels are the encoder inputs.
        lossfn_config (dict): extra options forwarded to calculate_l2_error.

    Returns:
        tuple: (summed batch losses divided by total sample count, across all
        ranks; updated cur_steps).
    """
    model.train()
    torch.cuda.set_device(rank)
    # ddp_loss[0] accumulates the per-batch losses, ddp_loss[1] the sample
    # count; kept as a CUDA tensor so it can be all-reduced at the end.
    ddp_loss = torch.zeros(2).cuda()
    for X, valid_len, valid_channels, h5_files in train_loader:
        optimizer.zero_grad(set_to_none=True)
        X = X.cuda()
        valid_len = valid_len.cuda()
        valid_channels = valid_channels.cuda()
        device = X.device
        dtype = X.dtype
        batch_size = X.shape[0]
        cur_steps += batch_size
        # Split X along the channel axis: everything before the last
        # out_channels feeds the encoder, the rest is the label.
        enc_inputs = X[:, :, :-out_channels]
        label = X[:, :, -out_channels:]
        valid_channels = valid_channels[:, -out_channels:]
        # Build teacher-forcing decoder inputs: prepend a zero step and drop
        # the final step so the decoder sees the label shifted right by one.
        padded_values = torch.zeros(
            label.shape[0], 1, label.shape[-1], device=device, dtype=dtype
        )
        dec_padded_inputs = torch.cat((padded_values, label), 1)
        dec_inputs = dec_padded_inputs[:, :-1, :]
        l2_loss = calculate_l2_error(
            model,
            enc_inputs,
            dec_inputs,
            cur_steps,
            label,
            valid_len,
            valid_channels,
            lossfn_config,
        )
        l2_loss.backward()
        optimizer.step()
        scheduler.step()
        ddp_loss[0] += l2_loss.detach().item()
        ddp_loss[1] += batch_size
    # Sum loss and sample counts over every rank before averaging.
    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    return ddp_loss[0] / ddp_loss[1], cur_steps


@torch.no_grad()
def eval_process(model, rank, val_loader, out_channels, lossfn_config):
    """
    Run one evaluation pass and reduce the loss across all ranks.

    Mirrors train_process but without gradients, optimizer or scheduler
    steps; cur_steps is passed as 0 because it is meaningless during
    evaluation.

    Args:
        model: network to evaluate.
        rank (int): local CUDA device index for this process.
        val_loader: yields (X, valid_len, valid_channels, h5_files) batches;
            h5_files is unused here.
        out_channels (int): number of trailing channels of X that are labels.
        lossfn_config (dict): extra options forwarded to calculate_l2_error.

    Returns:
        Tensor: summed batch losses divided by total sample count, across
        all ranks.
    """
    model.eval()
    torch.cuda.set_device(rank)
    # ddp_loss[0] accumulates the per-batch losses, ddp_loss[1] the sample
    # count; all-reduced across ranks at the end.
    ddp_loss = torch.zeros(2).to(rank)
    for X, valid_len, valid_channels, h5_files in val_loader:
        X = X.cuda()
        valid_len = valid_len.cuda()
        valid_channels = valid_channels.cuda()
        device = X.device
        dtype = X.dtype
        batch_size = X.shape[0]
        # Channel split: encoder inputs first, labels in the last
        # out_channels channels.
        enc_inputs = X[:, :, :-out_channels]
        label = X[:, :, -out_channels:]
        valid_channels = valid_channels[:, -out_channels:]
        # Teacher-forcing decoder inputs: label shifted right by one step
        # with a zero step prepended.
        padded_values = torch.zeros(
            label.shape[0], 1, label.shape[-1], device=device, dtype=dtype
        )
        dec_padded_inputs = torch.cat((padded_values, label), 1)
        dec_inputs = dec_padded_inputs[:, :-1, :]
        # cur_steps is meaningless when model evaluation.
        cur_steps = 0
        l2_loss = calculate_l2_error(
            model,
            enc_inputs,
            dec_inputs,
            cur_steps,
            label,
            valid_len,
            valid_channels,
            lossfn_config,
        )
        ddp_loss[0] += l2_loss.detach().item()
        ddp_loss[1] += batch_size
    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    return ddp_loss[0] / ddp_loss[1]


def fsdp_train(rank, world_size, config):
    """Entry point for one FSDP worker process.

    Builds the data pipeline, wraps the model with FSDP, and runs the
    train/eval loop; rank 0 prints metrics and saves checkpoints.

    Args:
        rank (int): this process's rank / local CUDA device index.
        world_size (int): total number of spawned processes.
        config (dict): parsed YAML configuration.
    """
    setup(rank, world_size)
    torch.cuda.set_device(rank)
    train_paras = config["data"]["train"]
    val_paras = config["data"]["validation"]
    loss_paras = config["loss"]["train"]
    # Keep lossfn_config a fresh dict (do not pass loss_paras through
    # directly) so later additions don't mutate the loaded config.
    lossfn_config = {"limiter_steps": loss_paras["limiter_steps"]}

    # create dataset & dataloader
    # NOTE(review): no DistributedSampler is attached, so every rank iterates
    # the full dataset each epoch — confirm this duplication is intended.
    train_set, val_set = create_datasets(config, is_debug=train_paras["is_debug"])
    train_loader = Data.DataLoader(
        train_set,
        batch_size=train_paras["batch_size"],
        num_workers=train_paras["num_workers"],
        collate_fn=pad_collate,
    )
    # Pad the step estimate — presumably so OneCycleLR is never stepped past
    # its total_steps (which raises) — TODO confirm the intent of this slack.
    train_steps_per_epoch = len(train_loader) + (train_paras["num_workers"] - 1)
    val_loader = Data.DataLoader(
        val_set,
        batch_size=val_paras["batch_size"],
        num_workers=val_paras["num_workers"],
        collate_fn=pad_collate,
    )
    # Bug fix: this previously read train_paras["num_workers"]; the val
    # loader's own worker count is the consistent choice.
    val_steps_per_epoch = len(val_loader) + (val_paras["num_workers"] - 1)

    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=100
    )
    # define models and optimizers
    model_params = config["model"]
    optim_params = config["optimizer"]
    model = tokmak_model(
        in_channels=model_params["in_channels"],
        hidden_size=model_params["hidden_size"],
        num_layers=model_params["num_layers"],
        dropout_rate=model_params["dropout_rate"],
        out_channels=model_params["out_channels"],
        noise_ratio=model_params["noise_ratio"],
    )
    model.cuda()
    # Bug fix: my_auto_wrap_policy was built but never passed, so FSDP fell
    # back to wrapping the model as a single unit.
    model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
    out_channels = model_params["out_channels"]

    # define optimizer & scheduler
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=float(optim_params["lr"]),
        weight_decay=optim_params["weight_decay"],
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=float(optim_params["lr"]),
        steps_per_epoch=train_steps_per_epoch,
        epochs=train_paras["epochs"],
    )

    epochs = train_paras["epochs"]
    cur_steps = 0
    for epoch in range(1, 1 + epochs):
        # train
        if rank == 0:
            local_time_beg = time.time()
        step_train_loss, cur_steps = train_process(
            model,
            rank,
            cur_steps,
            optimizer,
            scheduler,
            train_loader,
            out_channels,
            lossfn_config,
        )
        if rank == 0:
            # Bug fix: the elapsed time was multiplied by 1000 (milliseconds)
            # but printed with an "s" suffix; report true seconds instead.
            epoch_seconds = time.time() - local_time_beg
            step_seconds = epoch_seconds / train_steps_per_epoch
            # step_train_loss differ from the real train_loss.
            print(
                f"epoch: {epoch} train loss: {step_train_loss} "
                f"epoch time: {epoch_seconds:5.3f}s step time: {step_seconds:5.3f}s"
            )

        if epoch % config["summary"]["eval_interval_epochs"] == 0:
            if rank == 0:
                eval_time_start = time.time()
            step_val_loss = eval_process(
                model, rank, val_loader, out_channels, lossfn_config
            )
            if rank == 0:
                # Consistent units: seconds for the whole pass, ms per step.
                eval_seconds = time.time() - eval_time_start
                step_val_ms = eval_seconds / val_steps_per_epoch * 1000
                print(
                    f"epoch: {epoch} val loss: {step_val_loss} "
                    f"evaluation time: {eval_seconds:5.3f}s step time: {step_val_ms:5.3f}ms"
                )
                if config["summary"]["save_ckpt"]:
                    tempdir = config["summary"]["ckpt_dir"]
                    os.makedirs(tempdir, exist_ok=True)
                    # NOTE(review): under FSDP each rank holds only a shard of
                    # the parameters; saving model.state_dict() from rank 0
                    # alone may not capture the full model unless a
                    # full-state-dict policy is configured — confirm.
                    torch.save(
                        {"epoch": epoch, "model_state": model.state_dict()},
                        os.path.join(
                            tempdir,
                            f"fsdp-{epoch}-{step_train_loss:.5f}-{step_val_loss:.5f}.pt",
                        ),
                    )
    # Tear down the process group so workers exit cleanly.
    dist.destroy_process_group()


def train(config_file_path=None):
    """Load the YAML config and spawn one fsdp_train process per visible GPU.

    Args:
        config_file_path (str | None): path to the YAML configuration. When
            None (the previous behavior), falls back to the module-level
            ``args`` parsed under the __main__ guard — so the function can
            now also be called directly from other code with an explicit
            path, without requiring the CLI to have run first.
    """
    if config_file_path is None:
        config_file_path = args.config_file_path
    # load configurations
    config = load_yaml_config(config_file_path)
    world_size = torch.cuda.device_count()
    mp.spawn(
        fsdp_train,
        args=(world_size, config),
        nprocs=world_size,
        join=True,
    )


if __name__ == "__main__":
    from src import log_config

    # Set up file logging under ./logs before anything else prints.
    log_config("./logs", "tokmak")
    print("pid:", os.getpid())
    args = parse_args()
    print(f"Running in {args.mode.upper()} mode, using device id: {args.device_id}.")
    train()
